class Solution:
    def minMalwareSpread(self, graph: List[List[int]], initial: List[int]) -> int:
        """Return the infected node whose removal minimizes the final malware spread.

        Removing an initially infected node deletes it, together with its edges,
        from the graph. A connected component of clean nodes is saved by that
        removal only if no other infected node touches it. So the answer is the
        infected node that uniquely touches the largest total number of clean
        nodes. Ties are broken by the smallest node index.

        Args:
            graph: n x n adjacency matrix. ``graph[i][j] == 1`` marks an edge.
                Only pairs with i < j are scanned when joining clean nodes, so
                the matrix is assumed to be symmetric.
            initial: indices of the initially infected nodes.

        Returns:
            The index of the infected node to remove, or 0 if ``initial`` is empty.
        """
        n = len(graph)
        parent = list(range(n))
        size = [1] * n  # only meaningful at component roots

        def find(x: int) -> int:
            # Iterative find with path halving. The previous recursive version
            # could raise RecursionError on long parent chains (e.g. path graphs
            # with ~1000+ nodes).
            while parent[x] != x:
                parent[x] = parent[parent[x]]
                x = parent[x]
            return x

        def union(a: int, b: int) -> None:
            ra, rb = find(a), find(b)
            if ra == rb:
                return
            # Union by size keeps trees shallow.
            if size[ra] > size[rb]:
                ra, rb = rb, ra
            parent[ra] = rb
            size[rb] += size[ra]

        clean = [True] * n
        for node in initial:
            clean[node] = False

        # Group clean nodes into connected components, ignoring infected nodes.
        for i in range(n):
            if not clean[i]:
                continue
            for j in range(i + 1, n):
                if clean[j] and graph[i][j] == 1:
                    union(i, j)

        # For each infected node, collect the roots of the clean components it
        # touches. Also count how many infected nodes touch each component.
        touch_count = Counter()
        touched_roots = {}
        for node in initial:
            roots = {find(j) for j in range(n) if clean[j] and graph[node][j] == 1}
            touch_count.update(roots)  # each root counted once per infected node
            touched_roots[node] = roots

        best_saved, ans = -1, 0
        for node, roots in touched_roots.items():
            # Only components touched by this node alone are saved by removing it.
            saved = sum(size[r] for r in roots if touch_count[r] == 1)
            if saved > best_saved or (saved == best_saved and node < ans):
                best_saved, ans = saved, node
        return ans
